
import torch
import util
import objective_grad as gd
import matplotlib.pyplot as plt



def _diag_scale(u, M, v):
    '''
    Return diag(u) @ M @ diag(v) for 1-D u and v, using broadcasting instead of
    materializing the (possibly N x N) dense diagonal matrices.
    '''
    return u[:, None] * M * v[None, :]


def _select_projectors(semiRelaxedLeft, semiRelaxedRight, unbalanced):
    '''
    Pick the projection routines applied to the left factor Q and right factor R.

    util.semi_project_Balanced is used for a factor whose outer marginal (a or b)
    is kept balanced; util.project_Unbalanced is used for a factor whose outer
    marginal is relaxed. Flags are checked in priority order: semiRelaxedLeft,
    then semiRelaxedRight, then unbalanced. If none is set, both factors are
    balanced.

    Returns (project_Q, project_R).
    '''
    balanced, relaxed = util.semi_project_Balanced, util.project_Unbalanced
    if semiRelaxedLeft:
        return relaxed, balanced
    if semiRelaxedRight:
        return balanced, relaxed
    if unbalanced:
        return relaxed, relaxed
    return balanced, balanced


def _objective_cost(C, A, B, Q, R, Lambda, Wasserstein, FGW, alpha):
    '''
    Evaluate the objective at the coupling P = Q Lambda R^T.

    If Wasserstein: <C, P>_F.
    Otherwise the square-loss GW objective
        q^T (A**2) q + p^T (B**2) p - 2 <A P B, P>_F,   with q = Q 1, p = R 1
    (squares taken elementwise). When FGW is set, this is blended as
    (1 - alpha) <C, P>_F + alpha * GW.
    '''
    P = Q @ Lambda @ R.T
    if Wasserstein:
        return torch.sum(C * P)
    # Q 1 and R 1 stand in for the row / column sums of P; they coincide when the
    # row sums of T equal gQ and its column sums equal gR.
    q = Q.sum(dim=1)
    p = R.sum(dim=1)
    # <A P B, P>_F equals trace((A P B)^T P) but avoids forming an N2 x N2 product.
    cost = q @ (A**2) @ q + p @ (B**2) @ p - 2 * torch.sum((A @ P @ B) * P)
    if FGW:
        cost = (1 - alpha) * torch.sum(C * P) + alpha * cost
    return cost


def _plot_diagnostics(errs, T):
    '''
    Plot the recorded objective values across iterations, then display the latent coupling T.
    '''
    plt.plot(range(len(errs)), errs)
    plt.xlabel('Iterations')
    plt.ylabel('OT-Cost')
    plt.show()
    plt.imshow(T.cpu())
    plt.show()


def FRLC_iteration(C, a=None, b=None, A=None, B=None, tau=75, gamma=90, r=10, r2=None,
                   max_iter=200, device='cpu', dtype=torch.float64,
                   semiRelaxedLeft=False, semiRelaxedRight=False, Wasserstein=True,
                   printCost=True, returnFull=True, FGW=False, alpha=0.5, unbalanced=False,
                   initialization='Full', init_args=None, full_grad=True,
                   convergence_criterion=True, tol=1e-5, min_iter=25,
                   min_iterGW=500, max_iterGW=1000,
                   max_inneriters_balanced=300, max_inneriters_relaxed=50,
                   diagonalize_return=False):
    '''
    FRLC: Factor Relaxation with Latent Coupling.

    Learns a low-rank coupling P = Q Lambda R^T, with Lambda = diag(1/gQ) T diag(1/gR)
    and gQ = Q^T 1, gR = R^T 1. Each iteration takes a mirror-descent step on Q and R,
    projects them onto their marginal constraints, and then updates the latent
    coupling T with a Sinkhorn step.

    ------Parameters------
    C: torch.tensor (N1 x N2)
        A matrix of pairwise feature distances between space X and space Y (inter-space).
    a: torch.tensor (N1)
        Marginal one. Defaults to uniform.
    b: torch.tensor (N2)
        Marginal two. Defaults to uniform.
    A: torch.tensor (N1 x N1)
        A matrix of pairwise distances between points in metric space X (GW / FGW only).
    B: torch.tensor (N2 x N2)
        A matrix of pairwise distances between points in metric space Y (GW / FGW only).
    tau: float (> 0)
        A scalar which controls the regularity of the inner marginal update path.
    gamma: float (> 0)
        The mirror descent step-size, a scalar which controls the scaling of gradients
        before they are exponentiated into Sinkhorn kernels.
    r: int (>= 1)
        Size of the left inner marginal gQ, i.e. the rank associated with Q.
    r2: int (>= 1) or None
        Size of the right inner marginal gR, i.e. the rank associated with R.
        Defaults to r; a different value gives a non-square latent coupling T (r x r2).
    max_iter: int
        The maximal number of outer iterations (replaced by max_iterGW when
        Wasserstein is False).
    device: str
        The device (i.e. 'cpu' or 'cuda') which FRLC runs on.
    dtype: dtype
        The datatype used for the tensors created here (there is a space-accuracy
        tradeoff for low-rank between 32 and 64 bit).
    semiRelaxedLeft: bool
        True if running the left-marginal relaxed low-rank algorithm.
    semiRelaxedRight: bool
        True if running the right-marginal relaxed low-rank algorithm.
    Wasserstein: bool
        True if using the Wasserstein loss <C, P>_F as the objective cost.
        If False, runs GW when FGW is False and FGW when FGW is True.
    printCost: bool
        If True, evaluates and records the objective at every iteration. After the
        loop, plots the recorded values and displays the final latent coupling T
        with matplotlib. This is very expensive for large datasets.
        (A progress line "Iteration: k" is printed every 25 iterations regardless.)
    returnFull: bool
        True to return the dense coupling P = Q Lambda R^T; otherwise the
        factors (Q, R, T) are returned.
    FGW: bool
        True if running the Fused Gromov-Wasserstein problem; otherwise False.
    alpha: float
        Balance between the Wasserstein term (weight 1 - alpha) and the
        Gromov-Wasserstein term (weight alpha) of the FGW objective.
    unbalanced: bool
        True if running the unbalanced problem (both outer marginals relaxed).
        If semiRelaxedLeft, semiRelaxedRight and unbalanced are all False (the
        default), the balanced problem is solved.
    initialization: str, 'Full' or 'Rank-2'
        'Full' initializes the sub-couplings full-rank; 'Rank-2' uses a rank-2
        initialization. Any other value falls back to 'Full' with a message.
        We advise 'Full'.
    init_args: tuple of 3 tensors or None
        Warm start (Q0, R0, T0). The initial inner marginals are taken from the
        supplied factors: gQ = Q0^T 1, gR = R0^T 1.
    full_grad: bool
        If True, evaluates the gradients with rank-1 perturbations.
        If False, omits the perturbation terms.
    convergence_criterion: bool
        If True, stops once util.Delta between successive iterates drops to tol
        or below (after min_iter iterations). If False, runs for max_iter iterations.
    tol: float
        Tolerance used to decide when convergence is reached.
    min_iter: int
        The minimum number of iterations in the Wasserstein case.
    min_iterGW: int
        The minimum number of iterations in the GW / FGW case.
    max_iterGW: int
        The maximum number of iterations in the GW / FGW case.
    max_inneriters_balanced: int
        The maximum number of inner iterations for the Sinkhorn loops.
    max_inneriters_relaxed: int
        The maximum number of inner iterations for the relaxed and semi-relaxed projections.
    diagonalize_return: bool
        If True and returnFull is False, rewrite the LC-factorization to the form of
        Forrow et al. 2019: Q <- Q diag(1/gQ) T, T <- diag(R^T 1).
        It does not change the returned P when returnFull is True.

    ------Returns------
    (P, errs) if returnFull, else (Q, R, T, errs).
    errs holds one recorded objective value per iteration when printCost is True;
    otherwise it is empty.
    '''
    N1, N2 = C.size(dim=0), C.size(dim=1)
    k = 0

    one_N1 = torch.ones((N1), device=device, dtype=dtype)
    one_N2 = torch.ones((N2), device=device, dtype=dtype)

    if a is None:
        a = one_N1 / N1
    if b is None:
        b = one_N2 / N2
    if r2 is None:
        r2 = r

    one_r = torch.ones((r), device=device, dtype=dtype)
    one_r2 = torch.ones((r2), device=device, dtype=dtype)

    # Uniform inner marginals; r and r2 may differ, which gives a non-square latent coupling.
    gQ = (1 / r) * one_r
    gR = (1 / r2) * one_r2

    if initialization not in ('Full', 'Rank-2'):
        print('Initialization must be either "Full" or "Rank-2", defaulting to "Full".')
    full_rank = initialization != 'Rank-2'

    if init_args is None:
        Q, R, T, Lambda = util.initialize_couplings(a, b, gQ, gR, gamma, full_rank=full_rank,
                                                    device=device, dtype=dtype,
                                                    max_iter=max_inneriters_balanced)
    else:
        Q, R, T = init_args
        # Warm start: the inner marginals must agree with the supplied factors, which is
        # the same invariant (gQ = Q^T 1, gR = R^T 1) the main loop maintains.
        gQ, gR = Q.T @ one_N1, R.T @ one_N2
        Lambda = _diag_scale(1 / gQ, T, 1 / gR)

    if not Wasserstein:
        min_iter = min_iterGW
        max_iter = max_iterGW

    project_Q, project_R = _select_projectors(semiRelaxedLeft, semiRelaxedRight, unbalanced)

    errs = []
    gamma_k = gamma
    Q_prev, R_prev, T_prev = None, None, None

    # `Q_prev is None` guards util.Delta against being evaluated before any
    # previous iterate exists (possible when min_iter < 1).
    while k < max_iter and (
        not convergence_criterion
        or k < min_iter
        or Q_prev is None
        or util.Delta((Q, R, T), (Q_prev, R_prev, T_prev), gamma_k) > tol
    ):
        if convergence_criterion:
            # Keep the current iterates so convergence can be evaluated next round.
            Q_prev, R_prev, T_prev = Q, R, T

        if k % 25 == 0:
            print(f'Iteration: {k}')

        gradQ, gradR, gamma_k = gd.compute_grad_A(C, Q, R, Lambda, gamma, semiRelaxedLeft,
                                                  semiRelaxedRight, device, Wasserstein=Wasserstein,
                                                  A=A, B=B, FGW=FGW, alpha=alpha,
                                                  unbalanced=unbalanced, full_grad=full_grad)

        # Mirror-descent step on each factor, then projection onto its constraint set.
        xi1 = Q * torch.exp(-gamma_k * gradQ)
        xi2 = R * torch.exp(-gamma_k * gradR)

        u, v = project_Q(xi1, a, gQ, N1, gQ.shape[0], gamma_k, tau,
                         device=device, max_iter=max_inneriters_relaxed)
        Q = _diag_scale(u, xi1, v)

        u, v = project_R(xi2, b, gR, N2, gR.shape[0], gamma_k, tau,
                         device=device, max_iter=max_inneriters_relaxed)
        R = _diag_scale(u, xi2, v)

        gQ, gR = Q.T @ one_N1, R.T @ one_N2

        # Mirror-descent step on the latent coupling, projected onto couplings of (gQ, gR).
        gradT, gamma_T = gd.compute_grad_B(C, Q, R, Lambda, gQ, gR, gamma, device,
                                           Wasserstein=Wasserstein, A=A, B=B, FGW=FGW,
                                           alpha=alpha, full_grad=full_grad)
        xi3 = T * torch.exp(-gamma_T * gradT)
        u, v = util.Sinkhorn(xi3, gQ, gR, gQ.shape[0], gR.shape[0], gamma_T,
                             device=device, max_iter=max_inneriters_balanced)
        # In the semi-relaxed cases, make the marginal of T that pairs with the
        # balanced factor exact, so that the corresponding outer marginal of P is exact.
        if semiRelaxedLeft:
            v = gR / (xi3.T @ u)
        elif semiRelaxedRight:
            u = gQ / (xi3 @ v)

        # Latent transition matrix and its normalized form Lambda = diag(1/gQ) T diag(1/gR).
        T = _diag_scale(u, xi3, v)
        Lambda = _diag_scale(1 / gQ, T, 1 / gR)

        if printCost:
            cost = _objective_cost(C, A, B, Q, R, Lambda, Wasserstein, FGW, alpha)
            errs.append(cost.cpu())

        k += 1

    if printCost:
        _plot_diagnostics(errs, T)

    if returnFull:
        # The dense coupling is the same for every equivalent factorization, so it is
        # built from the LC factors (pairing a diagonalized Q with Lambda would apply T twice).
        return Q @ Lambda @ R.T, errs

    if diagonalize_return:
        # Diagonalize to the factorization of Forrow et al. 2019.
        Q = Q @ torch.diag(1 / gQ) @ T
        gR = R.T @ one_N2
        T = torch.diag(gR)

    return Q, R, T, errs






